{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qKlB5QTDIV6S"
      },
      "source": [
        "# Quantization (Sampling)\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/quantization_sampling.ipynb)\n",
        "\n",
        "This example shows how to use quantization with Gemma for inference. For quantization-aware training (QAT), see the [QAT finetuning](https://gemma-llm.readthedocs.io/en/latest/quantization_aware_training.html) example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TR-L25KVKT_F"
      },
      "outputs": [],
      "source": [
        "!pip install -q gemma"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I6fEKB1tISVW"
      },
      "outputs": [],
      "source": [
        "# Common imports\n",
        "import os\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import treescope\n",
        "\n",
        "# Gemma imports\n",
        "from gemma import gm\n",
        "from gemma import peft  # Parameter-efficient fine-tuning module"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cxGT2XeU4L47"
      },
      "source": [
        "By default, JAX preallocates only part of the GPU memory (75%), but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o4MidM--4L47"
      },
      "outputs": [],
      "source": [
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-kdAZkvOIryQ"
      },
      "source": [
        "## Initializing the model\n",
        "\n",
        "To use Gemma with quantization, wrap any Gemma model in `gm.nn.IntWrapper`. The example below sets up int8 weight inference:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x-BbrzCVIupV"
      },
      "outputs": [],
      "source": [
        "model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(text_only=True), dtype=jnp.int8)"
      ]
    },
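    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about what int8 weight quantization involves, the sketch below shows symmetric absmax quantization in plain JAX: each weight column gets one floating-point scale, and values are rounded to integers in [-127, 127]. The helper names `quantize_int8` and `dequantize_int8` are hypothetical, and the scheme is an assumption chosen for illustration; it is not necessarily what `gm.nn.IntWrapper` or `peft.quantize` do internally.\n",
        "\n",
        "```python\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def quantize_int8(w):\n",
        "  # Hypothetical helper (illustration only, not the Gemma implementation).\n",
        "  # One scale per column: the largest magnitude in the column maps to 127.\n",
        "  scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0\n",
        "  scale = jnp.where(scale == 0, 1.0, scale)  # Avoid dividing by zero\n",
        "  q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)\n",
        "  return q, scale\n",
        "\n",
        "\n",
        "def dequantize_int8(q, scale):\n",
        "  # Recover an approximation of the original float weights.\n",
        "  return q.astype(jnp.float32) * scale\n",
        "\n",
        "\n",
        "w = jnp.array([[0.5, -1.0], [0.25, 2.0]], dtype=jnp.float32)\n",
        "q, scale = quantize_int8(w)\n",
        "w_hat = dequantize_int8(q, scale)\n",
        "max_err = jnp.max(jnp.abs(w - w_hat))  # Bounded by half a scale step\n",
        "```\n",
        "\n",
        "The rounding error per weight is at most half a scale step, which is why a per-column scale tends to lose less precision than a single scale for the whole matrix."
      ]
    },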
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hI3Lg07SJff4"
      },
      "source": [
        "Initialize the weights:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1shC1DpiJfsw"
      },
      "outputs": [],
      "source": [
        "token_ids = jnp.zeros((1, 256), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)\n",
        "\n",
        "params = model.init(jax.random.key(0), token_ids)\n",
        "\n",
        "params = params['params']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bGJl5YpKKOf-"
      },
      "source": [
        "Restore the pre-trained params. We use `peft.quantize` to quantize the checkpoint.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AcO6oBuLKNjb"
      },
      "outputs": [],
      "source": [
        "del params\n",
        "# Load the params from the checkpoint\n",
        "original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)\n",
        "# Quantize the checkpoint params\n",
        "params = peft.quantize(original, method='INT8', checkpoint_kernel_key='w')"
      ]
    },
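    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to check what quantization changed is to compare the byte size of the parameter trees before and after. The sketch below uses a hypothetical helper, `tree_nbytes`, on toy pytrees; the toy shapes and key names are assumptions, and whether the tree returned by `peft.quantize` actually stores int8 arrays is not confirmed here.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def tree_nbytes(tree):\n",
        "  # Hypothetical helper: sum the byte size of every array leaf in a pytree.\n",
        "  leaves = jax.tree_util.tree_leaves(tree)\n",
        "  return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)\n",
        "\n",
        "\n",
        "# Toy stand-ins for a float32 tree and an int8 tree (assumed shapes).\n",
        "toy_fp32 = {'layer': {'w': jnp.zeros((4, 8), dtype=jnp.float32)}}\n",
        "toy_int8 = {'layer': {'w': jnp.zeros((4, 8), dtype=jnp.int8)}}\n",
        "\n",
        "fp32_bytes = tree_nbytes(toy_fp32)\n",
        "int8_bytes = tree_nbytes(toy_int8)\n",
        "```\n",
        "\n",
        "In this notebook, the same helper could be applied to `original` and `params` to compare the two trees."
      ]
    },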
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6jYRSpv2IwIY"
      },
      "source": [
        "## Fine-tuning\n",
        "\n",
        "See our [finetuning guide](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) for more info."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MvsQbQM4I4Cs"
      },
      "source": [
        "## Inference\n",
        "\n",
        "Here's an example of running a single model call:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 542
        },
        "executionInfo": {
          "elapsed": 11145,
          "status": "ok",
          "timestamp": 1741513754866,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "eqU7a4eCI5Wr",
        "outputId": "e21147d0-a206-493e-fff2-0590d5c7053e"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\u003chtml\u003e\n",
              "\u003chead\u003e\u003cmeta charset=\"utf-8\" /\u003e\u003c/head\u003e\n",
              "\u003cbody\u003e\n",
              "    \u003cdiv\u003e            \u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"\u003e\u003c/script\u003e\u003cscript type=\"text/javascript\"\u003eif (window.MathJax \u0026\u0026 window.MathJax.Hub \u0026\u0026 window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}\u003c/script\u003e                \u003cscript type=\"text/javascript\"\u003ewindow.PlotlyConfig = {MathJaxConfig: 'local'};\u003c/script\u003e\n",
              "        \u003cscript charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.34.0.min.js\"\u003e\u003c/script\u003e                \u003cdiv id=\"a4eb62b7-b734-44df-96f9-4d7a0a49ebcf\" class=\"plotly-graph-div\" style=\"height:525px; width:100%;\"\u003e\u003c/div\u003e            \u003cscript type=\"text/javascript\"\u003e                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"a4eb62b7-b734-44df-96f9-4d7a0a49ebcf\")) {                    Plotly.newPlot(                        \"a4eb62b7-b734-44df-96f9-4d7a0a49ebcf\",                        [{\"alignmentgroup\":\"True\",\"hovertemplate\":\"x=%{x}\\u003cbr\\u003ey=%{y}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\",\"legendgroup\":\"\",\"marker\":{\"color\":\"#636efa\",\"pattern\":{\"shape\":\"\"}},\"name\":\"\",\"offsetgroup\":\"\",\"orientation\":\"v\",\"showlegend\":false,\"textposition\":\"auto\",\"x\":[\"' Paris'\",\"' a'\",\"':'\",\"' the'\",\"' located'\",\"' **'\",\"' known'\",\"' '\",\"'...'\",\"'?'\",\"' in'\",\"' called'\",\"' not'\",\"','\",\"' what'\",\"' always'\",\"'\\\\n\\\\n'\",\"' which'\",\"' ____'\",\"' ________'\",\"' often'\",\"' one'\",\"' generally'\",\"' currently'\",\"' where'\",\"'\\u2026'\",\"' considered'\",\"' ____________'\",\"' now'\",\"' being'\"],\"xaxis\":\"x\",\"y\":[0.78515625,0.034423828125,0.023681640625,0.0208740234375,0.01263427734375,0.01123046875,0.01123046875,0.007720947265625,0.007720947265625,0.00531005859375,0.00531005859375,0.004974365234375,0.004119873046875,0.003204345703125,0.003021240234375,0.0028228759765625,0.0028228759765625,0.002655029296875,0.002349853515625,0.002197265625,0.0020751953125,0.0020751953125,0.0019378662109375,0.0018310546875,0.00171661376953125,0.001617431640625,0.00151824951171875,0.00125885009765625,0.00118255615234375,0.00091552734375],\"yaxis\":\"y\",\"type\":\"bar\"}],                        
{\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\
"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"t
icks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,
\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"title\":{\"text\":\"Tokens\"}},
\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"title\":{\"text\":\"Probability\"}},\"legend\":{\"tracegroupgap\":0},\"margin\":{\"t\":60},\"barmode\":\"relative\",\"title\":{\"text\":\"Probability Distribution of Tokens\"}},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('a4eb62b7-b734-44df-96f9-4d7a0a49ebcf');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            \u003c/script\u003e        \u003c/div\u003e\n",
              "\u003c/body\u003e\n",
              "\u003c/html\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "tokenizer = gm.text.Gemma3Tokenizer()\n",
        "\n",
        "prompt = tokenizer.encode('The capital of France is')\n",
        "prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)\n",
        "\n",
        "\n",
        "# Run the model\n",
        "out = model.apply(\n",
        "    {'params': params},\n",
        "    tokens=prompt,\n",
        "    return_last_only=True,  # Only predict the last token\n",
        ")\n",
        "\n",
        "\n",
        "# Show the token distribution\n",
        "tokenizer.plot_logits(out.logits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6dOSL9MHuMUa"
      },
      "source": [
        "To sample an entire sentence:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 10172,
          "status": "ok",
          "timestamp": 1741513765357,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "_ckwREdyqown",
        "outputId": "4f46f48b-4ad5-4a0b-d332-f47690a6f135"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "' Paris.\\n\\nParis is a global center for art, fashion, gastronomy, and culture. It is known for its iconic landmarks such as the Eiffel Tower'"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sampler = gm.text.Sampler(\n",
        "    model=model,\n",
        "    params=params,\n",
        "    tokenizer=tokenizer,\n",
        ")\n",
        "\n",
        "sampler.sample('The capital of France is', max_new_tokens=30)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {},
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
